import asyncio
import argparse
import json
import os
import re
import time
from collections import Counter
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any, Union
from datetime import datetime
import aiohttp  # Using aiohttp instead of openai

# ====================== Configuration ======================
@dataclass
class Config:
    """Settings shared by the LLM client and the solver."""

    # Path of the JSON dataset file opened when problems are loaded.
    dataset_path: str = "OpenBookQA.json"
    # Value sent in the "model" field of every request (placeholder).
    model: str = "your model"
    # API root; "/chat/completions" is appended to it (placeholder).
    base_url: str = "your base_url"
    # Number of reasoning paths sampled per question before voting.
    num_paths: int = 5

# ====================== LLM Client ======================
class LLMClient:
    """Wrapper for an OpenAI-style ``/chat/completions`` API using aiohttp.

    Attributes:
        config: Settings supplying the model name and the API base URL.
        token_counts: Running ``[input, output]`` token totals accumulated
            over every successful call made through this client.
    """

    def __init__(self, config: Config):
        self.config = config
        # [input_tokens, output_tokens], summed across all generate() calls.
        self.token_counts = [0, 0]

    async def generate(self, prompt: str) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Token totals in ``self.token_counts`` are updated from the ``usage``
        object of the response when the server reports one; otherwise they are
        estimated as one token per four characters.

        Args:
            prompt: Text sent as the content of the user message.

        Returns:
            The content of the first choice's message. A ``null`` content is
            returned as an empty string so callers can treat it as "no answer".

        Raises:
            aiohttp.ClientResponseError: The server answered with an HTTP
                status of 400 or higher.
            Exception: Any other failure (timeout, connection error, invalid
                JSON, unexpected response shape) is printed and re-raised.
        """
        payload = {
            "model": self.config.model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.3,
            "max_tokens": 8000,
            "top_p": 0.8,
        }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{self.config.base_url}/chat/completions",
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=120),
                ) as response:
                    # Surface HTTP errors with their status code instead of
                    # failing later with a KeyError on the error body.
                    response.raise_for_status()
                    resp = await response.json()

            # Content may legitimately be null; normalise it to "".
            content = resp["choices"][0]["message"]["content"] or ""

            usage = resp.get("usage")
            if not isinstance(usage, dict):
                usage = {}
            input_tokens = usage.get("prompt_tokens")
            output_tokens = usage.get("completion_tokens")
            # Fall back to the rough 4-characters-per-token estimate when the
            # server does not report exact counts.
            if not isinstance(input_tokens, int):
                input_tokens = len(prompt) // 4
            if not isinstance(output_tokens, int):
                output_tokens = len(content) // 4

            self.token_counts[0] += input_tokens
            self.token_counts[1] += output_tokens

            return content
        except Exception as e:
            print(f"LLM Error: {str(e)}")
            raise

# ====================== Core Solver ======================
class OpenBookQASolver:
    """Chain-of-Thought Self-Consistency OpenBookQA Solver.

    Each question is sent to the model ``config.num_paths`` times with the
    same prompt; the option letter extracted most often wins the vote.

    Attributes:
        config: Settings providing the dataset path and the number of paths.
        llm: Client used to query the model.
        stats: Running totals: problems seen, correct answers, accuracy in
            percent, and the client's cumulative ``[input, output]`` tokens.
    """

    # Letters offered to the model; the prompt and the answer extraction are
    # both limited to these, so any further dataset choices are not presented.
    _OPTION_LETTERS = "ABCD"

    # Explicit answer statements, tried in order. The phrase is matched
    # case-insensitively, but (outside \boxed{}) the letter must be an
    # upper-case A-D that is not the start of a word: lower-case letters are
    # too easily confused with ordinary words such as the article "a".
    _EXPLICIT_ANSWER_PATTERNS = (
        # \boxed{A}
        re.compile(r'\\boxed\{\s*([A-Da-d])\s*\}'),
        # "The correct answer is [A]", "Answer: A", "Final Answer: (A)",
        # "The correct option is A", "The best answer is **A**", ...
        re.compile(
            r'(?i:\b(?:answer|option|choice)\s*(?:is\s*:?|:)\s*(?:option\s+)?)'
            r'[\[\(\{<*"\s]*'
            r'([A-D])(?![A-Za-z])'
        ),
    )
    # Standalone letter with no further letters after it in the text.
    _TRAILING_LETTER = re.compile(r'\b([A-D])\b(?=[^A-Za-z]*$)')
    # (A), [A], {A} or <A>.
    _WRAPPED_LETTER = re.compile(r'[\(\[\{<]([A-D])[\)\]\}>]')
    # Any standalone upper-case option letter (last resort).
    _STANDALONE_LETTER = re.compile(r'\b([A-D])\b')

    def __init__(self, config: "Optional[Config]" = None):
        """Create a solver.

        Args:
            config: Settings to use; a default ``Config()`` when omitted.
        """
        self.config = config if config is not None else Config()
        self.llm = LLMClient(self.config)
        self.stats = {
            "total_problems": 0,
            "correct_answers": 0,
            "accuracy": 0.0,
            "tokens_used": [0, 0],
        }

    def _extract_answer(self, text: str) -> Optional[str]:
        """Extract the chosen option letter from a model response.

        Strategies are tried from most to least reliable, and within a
        strategy the LAST occurrence wins, because a chain-of-thought response
        discusses options first and states its conclusion at the end:

        1. an explicit statement (``\\boxed{X}``, "answer is X", "Answer: X");
        2. a standalone letter followed by no further letters;
        3. a letter wrapped in brackets, e.g. ``(X)`` or ``[X]``;
        4. any standalone upper-case letter A-D.

        Args:
            text: Raw response text.

        Returns:
            ``"A"``-``"D"``, or ``None`` when no option letter is found.
        """
        if not text:
            return None

        for pattern in self._EXPLICIT_ANSWER_PATTERNS:
            hits = pattern.findall(text)
            if hits:
                return hits[-1].upper()

        trailing = self._TRAILING_LETTER.search(text)
        if trailing:
            return trailing.group(1)

        for pattern in (self._WRAPPED_LETTER, self._STANDALONE_LETTER):
            hits = pattern.findall(text)
            if hits:
                return hits[-1]

        return None

    def _verify_answer(self, problem: Dict[str, Any], selected_answer: str) -> bool:
        """Return True when ``selected_answer`` equals the problem's ``answerKey``.

        Both sides are stripped and upper-cased before comparison; a missing
        or empty ``answerKey`` never matches.
        """
        correct_answer = str(problem.get("answerKey") or "").strip().upper()
        if not correct_answer:
            return False
        return selected_answer.strip().upper() == correct_answer

    async def load_problems(self, start_idx: int, end_idx: int) -> List[Dict]:
        """Load problems ``[start_idx:end_idx]`` from the configured dataset.

        The dataset file must contain a JSON list. Any failure to read or
        parse it is printed and results in an empty list (best effort).

        Args:
            start_idx: Index of the first problem to return.
            end_idx: Index one past the last problem to return.

        Returns:
            The selected slice of problems, or ``[]`` on error.
        """
        try:
            with open(self.config.dataset_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Error loading dataset: {str(e)}")
            return []

        if not isinstance(data, list):
            print(f"Error loading dataset: expected a JSON list, got {type(data).__name__}")
            return []
        return data[start_idx:end_idx]

    def _build_prompt(self, question: str, choices: List[str]) -> str:
        """Format the question and its lettered options into the CoT prompt."""
        options = "\n".join(
            f"{letter}. {choice}"
            for letter, choice in zip(self._OPTION_LETTERS, choices)
        )
        return (
            " \n"
            f"Question: {question}\n"
            "Options:\n"
            f"{options}\n"
            "Let's think step by step to solve the question, give the correct answer "
            'by stating "The correct answer is [X]" where [X] is exactly one letter '
            "(A, B, C, or D)."
        )

    async def solve_problem(self, problem: Dict[str, Any]) -> Dict[str, Any]:
        """Solve a problem using the Chain-of-Thought Self-Consistency approach.

        The same prompt is sent ``config.num_paths`` times, one call after the
        other; responses without an extractable answer are discarded and the
        most frequent remaining answer is selected. ``self.stats`` is updated
        as a side effect.

        Args:
            problem: Record with ``question_stem``, ``choices["text"]`` and,
                optionally, ``answerKey``.

        Returns:
            A dict with the question, the lettered options, the correct
            answer, the kept responses and their answers, the voted answer,
            whether it is correct, and the client's cumulative token counts
            at the time the problem finished.

        Raises:
            Exception: Whatever the LLM client raises for a failed request.
        """
        question = problem["question_stem"]
        # zip() against the option letters copes with fewer than four choices
        # instead of raising IndexError.
        choices = list(problem["choices"]["text"])[:len(self._OPTION_LETTERS)]
        prompt = self._build_prompt(question, choices)

        # Generate multiple reasoning paths
        responses = []
        answers = []
        for _ in range(self.config.num_paths):
            response = await self.llm.generate(prompt)
            answer = self._extract_answer(response)
            if answer:
                responses.append(response)
                answers.append(answer)

        # Select the most frequent answer (ties go to the one seen first)
        final_answer = Counter(answers).most_common(1)[0][0] if answers else None
        is_correct = final_answer is not None and self._verify_answer(problem, final_answer)

        # Update statistics
        self.stats["total_problems"] += 1
        if is_correct:
            self.stats["correct_answers"] += 1
        self.stats["accuracy"] = (
            self.stats["correct_answers"] / self.stats["total_problems"] * 100
        )
        self.stats["tokens_used"] = self.llm.token_counts.copy()

        return {
            "question": question,
            "options": dict(zip(self._OPTION_LETTERS, choices)),
            "correct_answer": str(problem.get("answerKey") or "").strip().upper(),
            "responses": responses,
            "answers": answers,
            "final_answer": final_answer,
            "is_correct": is_correct,
            "tokens_used": self.llm.token_counts.copy(),
        }

# ====================== Main Execution ======================
def _save_results(all_results, stats, start, end):
    """Write results and stats to the log directory and print a summary."""
    filename = f"log/OpenBookQA_cot_sc/results_{start}_{end}_acc{stats['accuracy']:.2f}%.json"

    with open(filename, "w", encoding="utf-8") as f:
        json.dump({
            "results": all_results,
            "stats": stats
        }, f, indent=2, ensure_ascii=False)

    print(f"\n{'='*50}")
    print(f"Results saved to {filename}")
    print(f"Total Problems: {stats['total_problems']}")
    print(f"Correct Answers: {stats['correct_answers']}")
    print(f"Accuracy: {stats['accuracy']:.2f}%")
    print(f"Total Tokens Used: Input={stats['tokens_used'][0]}, Output={stats['tokens_used'][1]}")
    print(f"{'='*50}")


async def main():
    """Run the solver over a slice of the dataset and save the results.

    Command-line options ``--start`` and ``--end`` select the slice. Each
    problem is solved in turn and its outcome printed. Results gathered so
    far are written to ``log/OpenBookQA_cot_sc`` even when a problem raises,
    after which the exception propagates as before.
    """
    parser = argparse.ArgumentParser(description="Chain-of-Thought Self-Consistency OpenBookQA Problem Solver")
    parser.add_argument("--start", type=int, default=0, help="Start index in dataset")
    parser.add_argument("--end", type=int, default=5, help="End index in dataset")
    args = parser.parse_args()

    os.makedirs("log/OpenBookQA_cot_sc", exist_ok=True)
    solver = OpenBookQASolver()
    problems = await solver.load_problems(args.start, args.end)
    if not problems:
        print(f"No problems loaded for range [{args.start}, {args.end}); nothing to do.")
        return

    all_results = []
    try:
        for idx, problem in enumerate(problems):
            print(f"\n{'='*50}\nProcessing problem {idx}: {problem['question_stem'][:50]}...\n{'='*50}")

            result = await solver.solve_problem(problem)
            all_results.append(result)

            # Print results for this problem
            print(f"Question: {result['question'][:100]}...")
            print(f"Generated Answers: {result.get('answers', [])}")
            print(f"Selected Answer (voting): {result.get('final_answer', '?')}")
            print(f"Correct Answer: {result['correct_answer']}")
            print(f"Correct: {result.get('is_correct', False)}")
            print(f"Tokens used: Input={result['tokens_used'][0]}, Output={result['tokens_used'][1]}")
    finally:
        # Save whatever was solved, so a failure late in a long run does not
        # throw away the problems already completed.
        if all_results:
            _save_results(all_results, solver.stats, args.start, args.end)

# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    asyncio.run(main())